//go:build sql_integration

package querymgr

import (
	"context"
	"testing"

	imageDS "github.com/stackrox/rox/central/image/datastore"
	imagePostgresV2 "github.com/stackrox/rox/central/image/datastore/store/v2/postgres"
	imageV2DS "github.com/stackrox/rox/central/imagev2/datastore"
	imageV2Postgres "github.com/stackrox/rox/central/imagev2/datastore/store/postgres"
	"github.com/stackrox/rox/central/ranking"
	mockRisks "github.com/stackrox/rox/central/risk/datastore/mocks"
	vulnReqCache "github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/cache"
	"github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/common"
	vulnReqDS "github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/datastore"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/concurrency"
	"github.com/stackrox/rox/pkg/features"
	"github.com/stackrox/rox/pkg/fixtures"
	"github.com/stackrox/rox/pkg/postgres/pgtest"
	"github.com/stackrox/rox/pkg/search"
	"github.com/stretchr/testify/suite"
	"go.uber.org/mock/gomock"
)

// TestVulnReqQueryManager is the go test entry point; it runs every test
// method defined on VulnReqQueryManagerTestSuite.
func TestVulnReqQueryManager(t *testing.T) {
	suite.Run(t, &VulnReqQueryManagerTestSuite{})
}

type VulnReqQueryManagerTestSuite struct {
	mockCtrl *gomock.Controller
	suite.Suite

	ctx    context.Context
	testDB *pgtest.TestPostgres

	vulnReqDataStore vulnReqDS.DataStore
	imageDataStore   imageDS.DataStore
	imageV2DataStore imageV2DS.DataStore
	mgr              *queryManagerImpl
	pendingReqCache  vulnReqCache.VulnReqCache
	activeReqCache   vulnReqCache.VulnReqCache

	testImage *storage.Image
}

// TearDownTest runs after each test method and finishes the gomock controller
// created in SetupTest.
func (s *VulnReqQueryManagerTestSuite) TearDownTest() {
	// Finish verifies that the expectations registered on the mocks were met.
	s.mockCtrl.Finish()
}

// SetupTest runs before each test method. It obtains a Postgres test database
// via pgtest.ForT, wires the image/vuln-request datastores and empty request
// caches into the query manager under test, and upserts the fixture image.
func (s *VulnReqQueryManagerTestSuite) SetupTest() {
	s.ctx = context.Background()
	s.mockCtrl = gomock.NewController(s.T())
	s.testDB = pgtest.ForT(s.T())

	s.pendingReqCache, s.activeReqCache = vulnReqCache.New(), vulnReqCache.New()
	s.createImageDataStore()
	s.createVulnRequestDataStore(s.pendingReqCache, s.activeReqCache)
	s.mgr = &queryManagerImpl{
		images:          s.imageDataStore,
		imageV2s:        s.imageV2DataStore,
		vulnReqs:        s.vulnReqDataStore,
		pendingReqCache: s.pendingReqCache,
		activeReqCache:  s.activeReqCache,
	}

	// Insert the test image. Every test depends on it, so stop the test right
	// away if seeding fails instead of continuing against an unseeded DB.
	// NOTE(review): only the v1 image datastore is seeded here; when
	// FlattenImageData is enabled, confirm whether imageV2DataStore also needs
	// this image for mgr.Images to find it.
	s.testImage = fixtures.GetImage()
	s.Require().NoError(s.imageDataStore.UpsertImage(allAccessCtx, s.testImage))
}

// createImageDataStore builds the Postgres-backed image datastore on the test
// database. When the FlattenImageData feature is enabled it also builds the
// v2 image datastore. Each datastore gets its own mock risk datastore and its
// own pair of rankers.
func (s *VulnReqQueryManagerTestSuite) createImageDataStore() {
	imageStore := imagePostgresV2.New(s.testDB.DB, false, concurrency.NewKeyFence())
	s.imageDataStore = imageDS.NewWithPostgres(
		imageStore,
		mockRisks.NewMockDataStore(s.mockCtrl),
		ranking.NewRanker(),
		ranking.NewRanker(),
	)

	// The v2 datastore is only needed when flattened image data is enabled.
	if !features.FlattenImageData.Enabled() {
		return
	}

	imageV2Store := imageV2Postgres.New(s.testDB.DB, false, concurrency.NewKeyFence())
	s.imageV2DataStore = imageV2DS.NewWithPostgres(
		imageV2Store,
		mockRisks.NewMockDataStore(s.mockCtrl),
		ranking.NewRanker(),
		ranking.NewRanker(),
	)
}

// createVulnRequestDataStore builds a Postgres-backed vulnerability-request
// datastore on the test database, backed by the given pending and active
// request caches, and stores it on the suite. A construction error fails the
// test immediately.
func (s *VulnReqQueryManagerTestSuite) createVulnRequestDataStore(pendingReqCache, activeReqCache vulnReqCache.VulnReqCache) {
	store, storeErr := vulnReqDS.GetTestPostgresDataStore(s.T(), s.testDB.DB, pendingReqCache, activeReqCache)
	s.Require().NoError(storeErr)
	s.vulnReqDataStore = store
}

// TestWithOneVulReq checks image, effective-request and vuln-state lookups
// when a single approved deferral request exists for the test image's exact
// registry/remote/tag scope.
func (s *VulnReqQueryManagerTestSuite) TestWithOneVulReq() {
	scope := common.VulnReqScope{
		Registry: s.testImage.GetName().GetRegistry(),
		Remote:   s.testImage.GetName().GetRemote(),
		Tag:      s.testImage.GetName().GetTag(),
	}
	cveToDefer := s.testImage.GetScan().GetComponents()[0].GetVulns()[0].GetCve()
	observedCVE := s.testImage.GetScan().GetComponents()[0].GetVulns()[1].GetCve()

	// Add an approved vulnerability request.
	req := fixtures.GetImageScopeDeferralRequest(scope.Registry, scope.Remote, scope.Tag, cveToDefer)
	req.Status = storage.RequestStatus_APPROVED
	s.Require().NoError(s.vulnReqDataStore.AddRequest(allAccessCtx, req))
	s.activeReqCache.Add(req)

	// Verify that the image is returned. Require the length before indexing so
	// an empty result fails the test instead of panicking.
	img, err := s.mgr.Images(allAccessCtx, req.GetId(), nil)
	s.Require().NoError(err)
	s.Require().Len(img, 1)
	s.Equal(s.testImage.GetId(), img[0].GetId())

	// Verify that no image is returned when queried for an image that is not in the DB.
	img, err = s.mgr.Images(allAccessCtx, req.GetId(), search.NewQueryBuilder().AddExactMatches(search.ImageName, "unavailable image").ProtoQuery())
	s.NoError(err)
	s.Empty(img)

	// Verify that the vulnerability request is returned for the deferred CVE.
	actualReq, err := s.mgr.EffectiveVulnReq(allAccessCtx, cveToDefer, scope)
	s.NoError(err)
	s.Equal(req.GetId(), actualReq.GetId())

	// Verify that no vulnerability request is returned for the observed CVE.
	actualReq, err = s.mgr.EffectiveVulnReq(allAccessCtx, observedCVE, scope)
	s.NoError(err)
	s.Nil(actualReq)

	// Verify that the state for all the CVEs is correct. Check presence first so
	// the loop below cannot pass vacuously on an empty result.
	cvesWithState, err := s.mgr.VulnsWithState(allAccessCtx, scope)
	s.NoError(err)
	s.Contains(cvesWithState, cveToDefer)
	for cve, state := range cvesWithState {
		if cve == cveToDefer {
			s.Equal(storage.VulnerabilityState_DEFERRED, state)
		} else {
			s.Equal(storage.VulnerabilityState_OBSERVED, state)
		}
	}
}

// TestWithMultipleActiveReqs checks lookups when approved requests exist both
// for the scope of the image in the DB and for a scope with no image in the DB.
func (s *VulnReqQueryManagerTestSuite) TestWithMultipleActiveReqs() {
	scopeInDB := common.VulnReqScope{
		Registry: s.testImage.GetName().GetRegistry(),
		Remote:   s.testImage.GetName().GetRemote(),
		Tag:      s.testImage.GetName().GetTag(),
	}
	scopeNotInDB := common.VulnReqScope{Registry: "fake", Remote: "fake", Tag: "fake"}

	cve1 := s.testImage.GetScan().GetComponents()[0].GetVulns()[0].GetCve()
	cve2 := s.testImage.GetScan().GetComponents()[0].GetVulns()[1].GetCve()

	// Add an approved vuln request for the scope of the image in the DB.
	reqForImgInDB := fixtures.GetImageScopeDeferralRequest(scopeInDB.Registry, scopeInDB.Remote, scopeInDB.Tag, cve1)
	reqForImgInDB.Status = storage.RequestStatus_APPROVED
	s.Require().NoError(s.vulnReqDataStore.AddRequest(allAccessCtx, reqForImgInDB))
	s.activeReqCache.Add(reqForImgInDB)

	// Add an approved vuln request for a scope with no image in the DB.
	reqForImgNotInDB := fixtures.GetImageScopeDeferralRequest(scopeNotInDB.Registry, scopeNotInDB.Remote, scopeNotInDB.Tag, cve2)
	reqForImgNotInDB.Status = storage.RequestStatus_APPROVED
	s.Require().NoError(s.vulnReqDataStore.AddRequest(allAccessCtx, reqForImgNotInDB))
	s.activeReqCache.Add(reqForImgNotInDB)

	// Verify that the vuln request is returned for the scope that has an image in the DB.
	actualReq, err := s.mgr.EffectiveVulnReq(allAccessCtx, cve1, scopeInDB)
	s.NoError(err)
	s.Equal(reqForImgInDB.GetId(), actualReq.GetId())

	// Verify that the vuln request is returned even though no image is in the DB for the given scope.
	actualReq, err = s.mgr.EffectiveVulnReq(allAccessCtx, cve2, scopeNotInDB)
	s.NoError(err)
	s.Equal(reqForImgNotInDB.GetId(), actualReq.GetId())

	// Verify that no vuln request is returned for a scope that has no vuln request.
	actualReq, err = s.mgr.EffectiveVulnReq(allAccessCtx, cve2, common.VulnReqScope{Registry: "invalid", Remote: "invalid", Tag: "invalid"})
	s.NoError(err)
	s.Nil(actualReq)

	// Verify that the state for all the CVEs in scopeInDB is correct. Only cve1
	// is deferred there; cve2's request belongs to scopeNotInDB, so cve2 must
	// not be reported as deferred for this scope.
	cvesWithState, err := s.mgr.VulnsWithState(allAccessCtx, scopeInDB)
	s.NoError(err)
	s.Contains(cvesWithState, cve1)
	for cve, state := range cvesWithState {
		if cve == cve1 {
			s.Equal(storage.VulnerabilityState_DEFERRED, state)
		} else {
			s.Equal(storage.VulnerabilityState_OBSERVED, state)
		}
	}
}

// TestWithMultipleScopes checks how exact-tag and all-tags (".*") requests for
// the same registry/remote interact in effective-request and state lookups.
func (s *VulnReqQueryManagerTestSuite) TestWithMultipleScopes() {
	scope := common.VulnReqScope{
		Registry: s.testImage.GetName().GetRegistry(),
		Remote:   s.testImage.GetName().GetRemote(),
		Tag:      s.testImage.GetName().GetTag(),
	}
	cve1 := s.testImage.GetScan().GetComponents()[0].GetVulns()[0].GetCve()
	cve2 := s.testImage.GetScan().GetComponents()[0].GetVulns()[1].GetCve()

	// Add an approved exact-tag deferral request for cve1.
	deferralReq1 := fixtures.GetImageScopeDeferralRequest(scope.Registry, scope.Remote, scope.Tag, cve1)
	deferralReq1.Status = storage.RequestStatus_APPROVED
	s.Require().NoError(s.vulnReqDataStore.AddRequest(allAccessCtx, deferralReq1))
	s.activeReqCache.Add(deferralReq1)

	// Add an approved all-tags false-positive request for cve1.
	fpReq := fixtures.GetImageScopeFPRequest(scope.Registry, scope.Remote, ".*", cve1)
	fpReq.Status = storage.RequestStatus_APPROVED
	s.Require().NoError(s.vulnReqDataStore.AddRequest(allAccessCtx, fpReq))
	s.activeReqCache.Add(fpReq)

	// Add an approved exact-tag deferral request for cve2.
	deferralReq2 := fixtures.GetImageScopeDeferralRequest(scope.Registry, scope.Remote, scope.Tag, cve2)
	deferralReq2.Status = storage.RequestStatus_APPROVED
	s.Require().NoError(s.vulnReqDataStore.AddRequest(allAccessCtx, deferralReq2))
	s.activeReqCache.Add(deferralReq2)

	// Verify that the exact-tag request is returned for the exact tag, even for
	// cve1, which is also covered by the all-tags request.
	expectedReqIDs := map[string]string{
		cve1: deferralReq1.GetId(),
		cve2: deferralReq2.GetId(),
	}
	for cve, expectedID := range expectedReqIDs {
		actualReq, err := s.mgr.EffectiveVulnReq(allAccessCtx, cve, scope)
		s.NoError(err)
		s.Equal(expectedID, actualReq.GetId(), "unexpected effective request for %s", cve)
	}

	// Verify that the all-tags request is returned for an image tag matched only by the all-tags regex.
	cloned := scope
	cloned.Tag = "latest-latest"
	actualReq, err := s.mgr.EffectiveVulnReq(allAccessCtx, cve1, cloned)
	s.NoError(err)
	s.Equal(fpReq.GetId(), actualReq.GetId())

	// Verify that no vuln request is returned because no vuln request exists for the CVE.
	actualReq, err = s.mgr.EffectiveVulnReq(allAccessCtx, "fake", scope)
	s.NoError(err)
	s.Nil(actualReq)

	// Verify that no vuln request is returned because no vuln request exists for the scope.
	cloned = scope
	cloned.Remote = "invalid"
	actualReq, err = s.mgr.EffectiveVulnReq(allAccessCtx, cve1, cloned)
	s.NoError(err)
	s.Nil(actualReq)

	// Verify that the state is deferred for the exact tag, which has deferral requests for both CVEs.
	cvesWithState, err := s.mgr.VulnsWithState(allAccessCtx, scope)
	s.NoError(err)
	s.Contains(cvesWithState, cve1)
	s.Contains(cvesWithState, cve2)
	for cve, state := range cvesWithState {
		if cve == cve1 || cve == cve2 {
			s.Equal(storage.VulnerabilityState_DEFERRED, state)
		} else {
			s.Equal(storage.VulnerabilityState_OBSERVED, state)
		}
	}

	// Verify that the state is false-positive for cve1 on a tag covered only by
	// the all-tags request. cve2's deferral is exact-tag only, so it does not
	// apply here.
	cloned = scope
	cloned.Tag = "latest-latest"
	cvesWithState, err = s.mgr.VulnsWithState(allAccessCtx, cloned)
	s.NoError(err)
	s.Contains(cvesWithState, cve1)
	for cve, state := range cvesWithState {
		if cve == cve1 {
			s.Equal(storage.VulnerabilityState_FALSE_POSITIVE, state)
		} else {
			s.Equal(storage.VulnerabilityState_OBSERVED, state)
		}
	}
}

// TestQueriesForActiveAndPendingRequests checks how pending requests interact
// with approved ones in effective-request and state lookups.
func (s *VulnReqQueryManagerTestSuite) TestQueriesForActiveAndPendingRequests() {
	scope := common.VulnReqScope{
		Registry: s.testImage.GetName().GetRegistry(),
		Remote:   s.testImage.GetName().GetRemote(),
		Tag:      s.testImage.GetName().GetTag(),
	}
	cve1 := s.testImage.GetScan().GetComponents()[0].GetVulns()[0].GetCve()
	cve2 := s.testImage.GetScan().GetComponents()[0].GetVulns()[1].GetCve()

	// Add a pending and an approved vuln request for the same CVE.
	deferralReq1 := fixtures.GetImageScopeDeferralRequest(scope.Registry, scope.Remote, scope.Tag, cve1)
	deferralReq1.Status = storage.RequestStatus_PENDING
	s.Require().NoError(s.vulnReqDataStore.AddRequest(allAccessCtx, deferralReq1))
	s.pendingReqCache.Add(deferralReq1)
	fpReq := fixtures.GetImageScopeFPRequest(scope.Registry, scope.Remote, scope.Tag, cve1)
	fpReq.Status = storage.RequestStatus_APPROVED
	s.Require().NoError(s.vulnReqDataStore.AddRequest(allAccessCtx, fpReq))
	s.activeReqCache.Add(fpReq)

	// Add a pending vuln request for the same scope but a different CVE.
	deferralReq2 := fixtures.GetImageScopeDeferralRequest(scope.Registry, scope.Remote, scope.Tag, cve2)
	deferralReq2.Status = storage.RequestStatus_PENDING
	s.Require().NoError(s.vulnReqDataStore.AddRequest(allAccessCtx, deferralReq2))
	s.pendingReqCache.Add(deferralReq2)

	// Verify that the approved request is returned when present (cve1) and the
	// pending request is returned otherwise (cve2).
	expectedReqIDs := map[string]string{
		cve1: fpReq.GetId(),
		cve2: deferralReq2.GetId(),
	}
	for cve, expectedID := range expectedReqIDs {
		actualReq, err := s.mgr.EffectiveVulnReq(allAccessCtx, cve, scope)
		s.NoError(err)
		s.Equal(expectedID, actualReq.GetId(), "unexpected effective request for %s", cve)
	}

	// Verify that the cve1 state is false-positive even though another request
	// for it is pending. Verify that the cve2 state is observed because cve2 has
	// only a pending request and no approved one.
	cvesWithState, err := s.mgr.VulnsWithState(allAccessCtx, scope)
	s.NoError(err)
	s.Contains(cvesWithState, cve1)
	for cve, state := range cvesWithState {
		switch cve {
		case cve1:
			s.Equal(storage.VulnerabilityState_FALSE_POSITIVE, state)
		case cve2:
			s.Equal(storage.VulnerabilityState_OBSERVED, state)
		}
	}
}
